package com.za.plugin.transfer.form.insertupdate;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import org.apache.ibatis.mapping.ParameterMapping;

import java.util.*;
import java.util.stream.Collectors;

/**
 * {@link StyleTransfer} that removes assignments to "auto-increment" columns from a MySQL
 * {@code UPDATE} statement and keeps the MyBatis parameter mappings aligned with the
 * {@code ?} placeholders that remain in the rewritten SQL.
 *
 * <p>Whenever no rewrite is needed (or a safe rewrite is impossible) the original SQL is
 * returned unchanged and every original parameter mapping is copied through.
 */
public class Update2StyleTransfer implements StyleTransfer {

    /**
     * Claims SQL whose trimmed text starts with the lower-case keyword {@code update}.
     *
     * <p>The check is case-sensitive: text starting with {@code UPDATE} is not claimed.
     *
     * @param sql SQL text to inspect; {@code null} is never supported
     * @return {@code true} if this transfer should handle {@code sql}
     */
    @Override
    public boolean isSupport(String sql) {
        return sql != null && sql.trim().startsWith("update");
    }

    /**
     * Rewrites a single {@code UPDATE} statement without its auto-increment assignments.
     *
     * <p>The SQL is returned unchanged, with all of {@code parameterMappingsCopy} appended to
     * {@code parameterMappings}, when any of the following holds:
     * <ul>
     *   <li>the text contains more than one {@code "update "} occurrence, or parses to more
     *       than one statement;</li>
     *   <li>{@code autoIncrProperties} is empty;</li>
     *   <li>the statement is not a MySQL {@code UPDATE};</li>
     *   <li>no {@code SET} column matches {@code autoIncrProperties};</li>
     *   <li>every {@code SET} column matches, so nothing would be left to assign.</li>
     * </ul>
     * Otherwise the rebuilt statement (terminated by {@code ;}) is returned. Only the mappings
     * whose placeholders survive the rewrite are appended, in their original order.
     *
     * @param sql                   original SQL text
     * @param tableName             not used by this transfer
     * @param autoIncrProperties    column names whose assignments are removed; matched
     *                              case-insensitively, also against table-qualified columns
     * @param parameterMappingsCopy mappings for the original SQL; assumed to hold one entry per
     *                              {@code ?} placeholder in textual order (TODO confirm with caller)
     * @param parameterMappings     output list receiving the mappings for the returned SQL
     * @param pKs                   not used by this transfer
     * @param pkAndUniqueKeys       not used by this transfer
     * @return the SQL to execute
     */
    @Override
    public String transfer(String sql, String tableName, List<String> autoIncrProperties,
                           List<ParameterMapping> parameterMappingsCopy,
                           List<ParameterMapping> parameterMappings,
                           Set<String> pKs, Map<String, List<String>> pkAndUniqueKeys) {
        // Cheap pre-check: scripts with several "update " occurrences are passed through
        // without being parsed at all.
        if (com.za.plugin.util.StrUtil.getCount(sql, "update ") > 1
                || CollectionUtil.isEmpty(autoIncrProperties)) {
            return passThrough(sql, parameterMappingsCopy, parameterMappings);
        }

        // Parse the original text rather than a lower-cased copy, so that string literals and
        // identifiers keep their case in the rewritten SQL.
        List<SQLStatement> statements = new MySqlStatementParser(sql).parseStatementList();
        if (statements.size() != 1 || !(statements.get(0) instanceof MySqlUpdateStatement)) {
            return passThrough(sql, parameterMappingsCopy, parameterMappings);
        }
        MySqlUpdateStatement stat = (MySqlUpdateStatement) statements.get(0);
        List<SQLUpdateSetItem> items = stat.getItems();
        if (CollectionUtil.isEmpty(items)) {
            return passThrough(sql, parameterMappingsCopy, parameterMappings);
        }

        String tableSource = SQLUtils.toMySqlString(stat.getTableSource());
        // Placeholders inside the table source (e.g. a joined sub-query) come before the SET ones.
        int paramIndex = getQuestionMarkCnt(tableSource);
        Set<Integer> droppedParamIndexes = new HashSet<>();
        List<String> keptAssignments = new ArrayList<>();
        boolean hasAutoIncr = false;
        for (SQLUpdateSetItem item : items) {
            if (item == null) {
                continue;
            }
            String column = SQLUtils.toMySqlString(item.getColumn());
            String value = SQLUtils.toMySqlString(item.getValue());
            // A value may hold zero, one or several placeholders (e.g. "version + 1", "?", "ifnull(?, ?)").
            int placeholders = getQuestionMarkCnt(value);
            if (isContains(column, autoIncrProperties)) {
                hasAutoIncr = true;
                for (int k = 0; k < placeholders; k++) {
                    droppedParamIndexes.add(paramIndex + k);
                }
            } else {
                keptAssignments.add(column + " = " + value);
            }
            paramIndex += placeholders;
        }

        // Nothing to strip, or stripping would leave an empty (invalid) SET clause.
        if (!hasAutoIncr || keptAssignments.isEmpty()) {
            return passThrough(sql, parameterMappingsCopy, parameterMappings);
        }

        StringBuilder sb = new StringBuilder("update ");
        sb.append(tableSource).append(" set ").append(String.join(",", keptAssignments));
        if (stat.getWhere() != null) {
            sb.append(" where ").append(SQLUtils.toMySqlString(stat.getWhere()));
        }
        // ORDER BY / LIMIT restrict which rows are updated; dropping them would widen the update.
        if (stat.getOrderBy() != null) {
            sb.append(' ').append(SQLUtils.toMySqlString(stat.getOrderBy()));
        }
        if (stat.getLimit() != null) {
            sb.append(' ').append(SQLUtils.toMySqlString(stat.getLimit()));
        }
        sb.append(';');

        for (int i = 0; i < parameterMappingsCopy.size(); i++) {
            if (!droppedParamIndexes.contains(i)) {
                parameterMappings.add(parameterMappingsCopy.get(i));
            }
        }
        return sb.toString();
    }

    /** Returns {@code sql} unchanged after copying every original parameter mapping through. */
    private static String passThrough(String sql, List<ParameterMapping> parameterMappingsCopy,
                                      List<ParameterMapping> parameterMappings) {
        parameterMappings.addAll(parameterMappingsCopy);
        return sql;
    }

    /**
     * Counts the {@code '?'} characters in {@code str}.
     *
     * @param str text to scan; blank or {@code null} yields 0
     * @return number of question marks
     */
    private static int getQuestionMarkCnt(String str) {
        if (StrUtil.isBlank(str)) {
            return 0;
        }
        int count = 0, length = str.length();
        for (int i = 0; i < length; i++) {
            if (str.charAt(i) == '?') {
                count++;
            }
        }
        return count;
    }

    /**
     * Tells whether a rendered SET column refers to one of the given properties.
     *
     * <p>The column matches when it equals a property, or ends with {@code "." + property}
     * (table-qualified). The comparison ignores spaces, backticks and case.
     *
     * @param column rendered column, e.g. {@code t.version} or {@code `VERSION`}
     * @param set    property names to match against; {@code null} entries are ignored
     * @return {@code true} if the column matches any property
     */
    private static boolean isContains(String column, List<String> set) {
        String normalizedColumn = normalizeIdentifier(column);
        for (String s : set) {
            if (s == null) {
                continue;
            }
            String property = normalizeIdentifier(s);
            // endsWith (not contains) so that "t.version_no" does not match "version".
            if (normalizedColumn.equals(property) || normalizedColumn.endsWith("." + property)) {
                return true;
            }
        }
        return false;
    }

    /** Strips spaces and backticks and lower-cases with {@link Locale#ROOT} for identifier comparison. */
    private static String normalizeIdentifier(String identifier) {
        return identifier.replace(" ", "").replace("`", "").toLowerCase(Locale.ROOT);
    }
}
